# -----------------------
# Benchmark simulation
# -----------------------
# Project setup. Paths are relative and chained: after moving up one
# directory, "../FUNCTIONS.R" resolves two levels above the launch directory.
# NOTE(review): setwd() changes global state for everything run afterwards;
# absolute paths (or a project root variable) would be more robust.
setwd("..")
source(file.path("..", "FUNCTIONS.R"))
outpath <- ".."
Rcpp::sourceCpp(file.path("..", "FUNCTIONS_CPP_JGNSC.cpp"))

library(pheatmap)
library(umap)
library(huge)
library(JGL)
library(PRROC)

set.seed(simnum)

# Simulate one count set per condition from the first two covariance matrices
# in sigma.list.1 (100 genes, nsample samples each).
count1 <- CountMap5(sigma = sigma.list.1[[1]], ngene = 100, n = nsample)
count2 <- CountMap5(sigma = sigma.list.1[[2]], ngene = 100, n = nsample)
countlist <- list(count1, count2)

# JGNsc imputation of each condition; the count matrix is transposed before
# being passed as `y`. Conditions are processed in order (1 then 2), so any
# random draws happen in the same sequence.
implist <- lapply(countlist, function(sim) {
  JGNsc.cont2(y = t(sim$count), dropThreshold = 0.75, min.cell = 1,
              warm = 200, iter = 500)
})
count.imp1 <- implist[[1]]
count.imp2 <- implist[[2]]

# Observed (with-dropout) counts in the same transposed layout.
observed.list <- lapply(countlist, function(sim) t(sim$count))

# ------------------------------------------------------------
# Iterative masked imputation (JGNsc_iterimp)
# ------------------------------------------------------------
# Starts from the JGNsc imputations in implist, using the transposed observed
# counts in observed.list. mrate is divided by 100 before use as mask.rate
# (presumably a percentage; confirm against the driver).
impiter <- JGNsc_iterimp(
  observed.list = observed.list,
  imputedList = implist,
  mask.rate = mrate / 100
)

# Nonparanormal transform (huge.npn) of each iteratively imputed matrix.
theta.star.npn <- lapply(impiter, function(mat) huge.npn(mat))
 
# ------------------------
# Gaussian Transformation
# ------------------------
# Nonparanormal-transform (huge.npn) three versions of each condition's data:
# the observed counts, the JGNsc-imputed values, and the dropout-free counts.

stopifnot(length(implist) == length(countlist))

observedGauList <- vector("list", length(countlist))
JGNscGauList <- vector("list", length(countlist))
NodropCountGauList <- vector("list", length(countlist))

# seq_along() rather than 1:length() so an empty countlist yields empty
# lists instead of an out-of-bounds error on countlist[[1]].
for (k in seq_along(countlist)) {
  observedGauList[[k]] <- huge.npn(countlist[[k]]$count)
  # y.impute is transposed; presumably this restores the orientation of
  # countlist[[k]]$count, since JGNsc.cont2 was given t(count) (TODO confirm).
  JGNscGauList[[k]] <- huge.npn(t(implist[[k]]$y.impute))
  NodropCountGauList[[k]] <- huge.npn(countlist[[k]]$count.nodrop)
}


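# -------------------------------
# JGL fits with tuning-parameter selection
# Inputs are n-by-p data matrices (samples by genes).
# Tuning parameters are selected by AIC, BIC and EBIC.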

# Fit joint graphical lasso (JGL) models over a (lambda1, lambda2) grid and
# return partial-correlation estimates at the AIC-, BIC- and EBIC-optimal
# tuning parameters.
#
# Args:
#   inputgau: list of Gaussianized data matrices, one per condition; passed
#             unchanged to tuning_select() and JGL().
#   l1.vec:   candidate lambda1 values; NULL (the default) uses
#             seq(1, 30, by = 2) / 100.
#   l2.vec:   candidate lambda2 values; NULL (the default) uses
#             seq(1, 20, by = 2) / 100.
#
# Returns a list with:
#   partcorr.aic, partcorr.bic, partcorr.ebic: prec2partialcorr() applied to
#     each element of $theta of the JGL fit chosen by that criterion;
#   tuningparam: one row per grid point (lam2 varying fastest) with columns
#     lam1, lam2, aic, bic, ebic, ebic1, ebic2.
#
# Errors if either grid is empty or a criterion is NA at every grid point.
getPartcorrViaTPSelect <- function(inputgau, l1.vec = NULL, l2.vec = NULL) {
  if (is.null(l1.vec)) {
    l1.vec <- seq(1, 30, by = 2) / 100
  }
  if (is.null(l2.vec)) {
    l2.vec <- seq(1, 20, by = 2) / 100
  }
  if (length(l1.vec) == 0L || length(l2.vec) == 0L) {
    stop("l1.vec and l2.vec must each contain at least one value",
         call. = FALSE)
  }

  # Score every grid point. Rows are collected in a list and bound once
  # instead of growing a matrix with rbind() on every iteration.
  rows <- vector("list", length(l1.vec) * length(l2.vec))
  i <- 0L
  for (lam1 in l1.vec) {
    for (lam2 in l2.vec) {
      cat("tuning parameter:", lam1, ", ", lam2, "\n")
      i <- i + 1L
      rows[[i]] <- tuning_select(inputgau, lam1 = lam1, lam2 = lam2)[[1]]
    }
  }
  tuningparam <- do.call(rbind, rows)
  colnames(tuningparam) <- c("lam1", "lam2", "aic", "bic", "ebic",
                             "ebic1", "ebic2")

  # Refit JGL at the grid point minimizing `criterion` and convert each
  # estimated precision matrix to partial correlations.
  fit_at_best <- function(criterion) {
    best <- which.min(tuningparam[, criterion])
    if (length(best) == 0L) {
      stop("all ", criterion, " values on the tuning grid are NA",
           call. = FALSE)
    }
    fit <- JGL(inputgau,
               lambda1 = tuningparam[best, "lam1"],
               lambda2 = tuningparam[best, "lam2"],
               return.whole.theta = TRUE)
    lapply(fit$theta, prec2partialcorr)
  }

  list(partcorr.aic = fit_at_best("aic"),
       partcorr.bic = fit_at_best("bic"),
       partcorr.ebic = fit_at_best("ebic"),
       tuningparam = tuningparam)
}

# Per-method partial-correlation estimates under each tuning-parameter
# selection rule, plus tables of the selected tuning parameters.
plist.aic <- list()
plist.bic <- list()
plist.ebic <- list()
plist.stars <- list()
stars.table <- NULL
tp.table <- NULL

# Gaussianized inputs to compare: dropout-free counts (reference) and the
# JGNsc + iterative-imputation output. `methods` labels them in the same order.
methods.list <- list(NodropCountGauList,
                     theta.star.npn)
methods <- c("NoDropout",
             "JGNsc + iter")
methods.order <- methods
stopifnot(length(methods.list) == length(methods))

# lambda1 grid shared by the AIC/BIC/EBIC search and by StARS so both
# selection routes always search the same range.
lambda1.grid <- seq(1, 35, by = 2) / 100
lambda2.grid <- seq(1, 15, by = 2) / 100

for (mm in seq_along(methods.list)) {
  tp.temp <- getPartcorrViaTPSelect(methods.list[[mm]],
                                    l1.vec = lambda1.grid,
                                    l2.vec = lambda2.grid)
  stars.temp <- stars.select(methods.list[[mm]], lambvec = lambda1.grid)

  plist.aic[[mm]] <- tp.temp$partcorr.aic
  plist.bic[[mm]] <- tp.temp$partcorr.bic
  plist.ebic[[mm]] <- tp.temp$partcorr.ebic
  plist.stars[[mm]] <- stars.temp$partcorr

  tp.table <- rbind(tp.table,
                    cbind.data.frame(tp.temp$tuningparam, methods[mm]))
  # Note: c() of numbers and a string yields a character row.
  stars.table <- rbind(stars.table,
                       c(stars.temp$stars.l1$opt.lambda.correction,
                         stars.temp$stars.l2$opt.lambda.correction,
                         methods[mm]))
}


# -------------------------------
# Ground-truth partial correlations, derived from the precision matrix stored
# with each simulated count set, then truncated with trunc_precision().
partcorr.true <- lapply(countlist, function(sim) prec2partialcorr(sim$precision))
partcorr.true.trunc <- lapply(partcorr.true, trunc_precision)

# True adjacency: 1 where the truncated partial correlation is nonzero, else 0.
# NOTE(review): trunc_precision() is applied a second time here; harmless only
# if it is idempotent.
trueadj.list <- lapply(partcorr.true.trunc,
                       function(pc) abs(sign(trunc_precision(pc))))

# Score every selection rule's estimates against the truncated truth.
eval.table.aic <- eval_partcorrList(
  plist = plist.aic,
  partcorr.true.trunc = partcorr.true.trunc,
  methods = methods,
  methods.order = methods.order
)
eval.table.bic <- eval_partcorrList(
  plist = plist.bic,
  partcorr.true.trunc = partcorr.true.trunc,
  methods = methods,
  methods.order = methods.order
)
eval.table.ebic <- eval_partcorrList(
  plist = plist.ebic,
  partcorr.true.trunc = partcorr.true.trunc,
  methods = methods,
  methods.order = methods.order
)
eval.table.stars <- eval_partcorrList(
  plist = plist.stars,
  partcorr.true.trunc = partcorr.true.trunc,
  methods = methods,
  methods.order = methods.order
)

tpnames <- c("AIC", "BIC", "EBIC", "STARS")
tpeval <- list(eval.table.aic,
               eval.table.bic,
               eval.table.ebic,
               eval.table.stars)

# Identify the run in every output file name so runs with different seeds,
# sample sizes or mask rates do not overwrite each other's results.
# NOTE(review): if a driver substitutes simnum/nsample/mrate textually in this
# script, this still produces "<simnum>_<nsample>_<mrate>".
run.tag <- paste(simnum, nsample, mrate, sep = "_")

# Write one headerless, tab-separated table for a metric / selection rule.
write_eval_metric <- function(x, tp, metric) {
  fname <- paste0("D0002mask_", tp, "_", metric, "_", run.tag, ".txt")
  write.table(x, file.path(outpath, fname),
              quote = FALSE, sep = "\t", row.names = FALSE, col.names = FALSE)
}

for (tt in seq_along(tpnames)) {
  TP <- tpnames[tt]
  evaltable <- tpeval[[tt]]
  write_eval_metric(evaltable$sse, TP, "SSE")
  write_eval_metric(evaltable$pearson, TP, "pcor")
  write_eval_metric(evaltable$auc, TP, "auc")
  write_eval_metric(evaltable$prc, TP, "auprc")
}
